{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "# When this file is run directly for testing, the import path has to be adjusted to avoid import problems:\n",
    "# the directory two levels above the working directory is put at the front of sys.path.\n",
    "BASE_DIR = os.path.dirname(os.path.dirname(os.getcwd()))\n",
    "sys.path.insert(0, os.path.join(BASE_DIR))\n",
    "\n",
    "PYSPARK_PYTHON = \"/miniconda2/envs/reco_sys/bin/python\"\n",
    "# With several Python versions installed, leaving the interpreter unspecified is likely to cause errors.\n",
    "os.environ[\"PYSPARK_PYTHON\"] = PYSPARK_PYTHON\n",
    "os.environ[\"PYSPARK_DRIVER_PYTHON\"] = PYSPARK_PYTHON\n",
    "\n",
    "from pyspark.ml.feature import OneHotEncoder\n",
    "from pyspark.ml.feature import StringIndexer\n",
    "from pyspark.ml import Pipeline\n",
    "from pyspark.sql.types import *\n",
    "from pyspark.ml.feature import VectorAssembler\n",
    "from pyspark.ml.classification import LogisticRegression\n",
    "from pyspark.ml.classification import LogisticRegressionModel\n",
    "from offline import SparkSessionBase\n",
    "\n",
    "class CtrLogisticRegression(SparkSessionBase):\n",
    "    \"\"\"Holder for the Spark session used by the CTR logistic-regression cells.\n",
    "\n",
    "    The session is built by ``_create_spark_hbase``, which is not defined in this\n",
    "    class (presumably inherited from ``SparkSessionBase`` -- confirm).\n",
    "    \"\"\"\n",
    "\n",
    "    # Class-level settings; presumably read by SparkSessionBase when it builds the session -- confirm.\n",
    "    SPARK_APP_NAME = \"ctrLogisticRegression\"\n",
    "    ENABLE_HIVE_SUPPORT = True\n",
    "\n",
    "    def __init__(self):\n",
    "        \"\"\"Create the Spark session and keep it on ``self.spark``.\"\"\"\n",
    "\n",
    "        self.spark = self._create_spark_hbase()\n",
    "\n",
    "# Single instance; the cells below run their queries through `ctr.spark`.\n",
    "ctr = CtrLogisticRegression()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (1) Read the user behaviour log, keeping only the user_id, article_id and clicked columns.\n",
    "ctr.spark.sql('use profile')\n",
    "user_article_basic = ctr.spark.sql(\"select * from user_article_basic\").select(['user_id', 'article_id', 'clicked'])\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------------------+-------------------+-------+\n",
      "|            user_id|         article_id|clicked|\n",
      "+-------------------+-------------------+-------+\n",
      "|1105045287866466304|              14225|  false|\n",
      "|1106476833370537984|              14208|  false|\n",
      "|1109980466942836736|              19233|  false|\n",
      "|1109980466942836736|              44737|  false|\n",
      "|1109993249109442560|              17283|  false|\n",
      "|1111189494544990208|              19322|  false|\n",
      "|1111524501104885760|              44161|  false|\n",
      "|1112727762809913344|              18172|   true|\n",
      "|1113020831425888256|1112592065390182400|  false|\n",
      "|1114863735962337280|              17665|  false|\n",
      "|1114863741448486912|              14208|  false|\n",
      "|1114863751909081088|              13751|  false|\n",
      "|1114863846486441984|              17940|  false|\n",
      "|1114863941936218112|              15196|  false|\n",
      "|1114863998437687296|              19233|  false|\n",
      "|1114864164158832640|             141431|  false|\n",
      "|1114864237131333632|              13797|  false|\n",
      "|1114864354622177280|             134812|  false|\n",
      "|1115089292662669312|1112608068731928576|  false|\n",
      "|1115534909935452160|              18156|  false|\n",
      "+-------------------+-------------------+-------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Preview the selected behaviour-log columns.\n",
    "user_article_basic.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------+------+--------+--------------------+\n",
      "|             user_id|gender|birthday|     article_partial|\n",
      "+--------------------+------+--------+--------------------+\n",
      "|              user:1|  null|     0.0|Map(18:vars -> 0....|\n",
      "|             user:10|  null|     0.0|Map(18:tp2 -> 0.2...|\n",
      "|             user:11|  null|     0.0|               Map()|\n",
      "|user:110249052282...|  null|     0.0|               Map()|\n",
      "|user:110319567345...|  null|    null|Map(18:Animal -> ...|\n",
      "|user:110504528786...|  null|    null|Map(18:text -> 0....|\n",
      "|user:110509388310...|  null|    null|Map(18:text -> 0....|\n",
      "|user:110510518565...|  null|    null|Map(18:SHOldboySt...|\n",
      "|user:110639618314...|  null|    null|Map(18:tp2 -> 0.2...|\n",
      "|user:110647320376...|  null|    null|Map(18:text -> 0....|\n",
      "|user:110647683337...|  null|    null|Map(18:text -> 1....|\n",
      "|user:110826490119...|  null|    null|Map(18:text -> 0....|\n",
      "|user:110997636345...|  null|    null|Map(18:text -> 0....|\n",
      "|user:110997980510...|  null|    null|Map(18:text -> 0....|\n",
      "|user:110998046694...|  null|    null|Map(18:text -> 0....|\n",
      "|user:110998427383...|  null|    null|Map(18:text -> 0....|\n",
      "|user:110999324910...|  null|    null|Map(18:text -> 0....|\n",
      "|user:110999459420...|  null|    null|Map(18:tp2 -> 0.2...|\n",
      "|user:110999526437...|  null|    null|Map(18:text -> 0....|\n",
      "|user:110999568377...|  null|    null|Map(18:text -> 0....|\n",
      "+--------------------+------+--------+--------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# (2) Read the user profiles, process them and merge them with the log data.\n",
    "# gender and birthday are read as fields of the `information` column.\n",
    "user_profile_hbase = ctr.spark.sql(\"select user_id, information.gender, information.birthday, article_partial from user_profile_hbase\")\n",
    "user_profile_hbase.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 对于用户ID做一个处理，取出前面的user字符串\n",
    "def deal_with_user_id(row):\n",
    "    return int(row.user_id.split(':')[1]), row.gender, row.birthday, row.article_partial\n",
    "\n",
    "# Wrong -- kept for reference only:\n",
    "# user_profile_hbase = user_profile_hbase.rdd.map(deal_with_user_id).toDF(['user_id', 'gender', 'birthday', 'article_partial'])\n",
    "# Keep the mapped rows as an RDD; the DataFrame is built in a later cell with an explicit schema.\n",
    "user_profile = user_profile_hbase.rdd.map(deal_with_user_id)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Explicit schema for the tuples returned by deal_with_user_id, in the same field order.\n",
    "_schema = StructType([\n",
    "    StructField('user_id', LongType()),\n",
    "    StructField('gender', BooleanType()),\n",
    "    StructField('birthday', DoubleType()),\n",
    "    StructField('article_partial', MapType(StringType(), DoubleType()))\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the profile DataFrame from the mapped RDD, then drop the gender and birthday columns.\n",
    "# NOTE(review): this rebinds `user_profile_hbase`; the mapping cell that reads row.gender / row.birthday\n",
    "# can then only be re-run after the table has been read again.\n",
    "user_profile_hbase = ctr.spark.createDataFrame(user_profile, schema=_schema).drop('gender').drop('birthday')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DataFrame[user_id: bigint, article_partial: map<string,double>]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the converted profile frame (its repr lists column names and types).\n",
    "user_profile_hbase"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Left join on user_id: every log row is kept, with the user's article_partial map attached when a profile matches.\n",
    "train = user_article_basic.join(user_profile_hbase, on=['user_id'], how='left')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------------------+----------+-------+--------------------+\n",
      "|            user_id|article_id|clicked|     article_partial|\n",
      "+-------------------+----------+-------+--------------------+\n",
      "|1106473203766657024|     16005|  false|Map(18:text -> 0....|\n",
      "|1106473203766657024|     17665|  false|Map(18:text -> 0....|\n",
      "|1106473203766657024|     44664|  false|Map(18:text -> 0....|\n",
      "|1106473203766657024|     44386|  false|Map(18:text -> 0....|\n",
      "|1106473203766657024|     14335|  false|Map(18:text -> 0....|\n",
      "|1106473203766657024|     13778|  false|Map(18:text -> 0....|\n",
      "|1106473203766657024|     13039|  false|Map(18:text -> 0....|\n",
      "|1106473203766657024|     13648|  false|Map(18:text -> 0....|\n",
      "|1106473203766657024|     17304|  false|Map(18:text -> 0....|\n",
      "|1106473203766657024|     19233|  false|Map(18:text -> 0....|\n",
      "|1106473203766657024|     44466|  false|Map(18:text -> 0....|\n",
      "|1106473203766657024|     18795|  false|Map(18:text -> 0....|\n",
      "|1106473203766657024|    134812|  false|Map(18:text -> 0....|\n",
      "|1106473203766657024|     13357|  false|Map(18:text -> 0....|\n",
      "|1106473203766657024|     19171|  false|Map(18:text -> 0....|\n",
      "|1106473203766657024|     44104|  false|Map(18:text -> 0....|\n",
      "|1106473203766657024|     13340|  false|Map(18:text -> 0....|\n",
      "|1106473203766657024|     14225|  false|Map(18:text -> 0....|\n",
      "|1106473203766657024|     44739|  false|Map(18:text -> 0....|\n",
      "|1106473203766657024|     19016|  false|Map(18:text -> 0....|\n",
      "+-------------------+----------+-------+--------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Preview the log rows joined with the user profiles.\n",
    "train.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (3) Read the article channels and vectors and merge them in; drop unused features and\n",
    "# merge the weight features from the article profiles.\n",
    "ctr.spark.sql(\"use article\")\n",
    "article_vector = ctr.spark.sql(\"select * from article_vector\")\n",
    "train = train.join(article_vector, on=['article_id'], how='left')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------+-------------------+-------+--------------------+----------+--------------------+\n",
      "|article_id|            user_id|clicked|     article_partial|channel_id|       articlevector|\n",
      "+----------+-------------------+-------+--------------------+----------+--------------------+\n",
      "|     13401|1114864237131333632|  false|Map(18:vars -> 0....|        18|[0.06157120217893...|\n",
      "|     13401|                 10|  false|Map(18:tp2 -> 0.2...|        18|[0.06157120217893...|\n",
      "|     13401|1106396183141548032|  false|Map(18:tp2 -> 0.2...|        18|[0.06157120217893...|\n",
      "|     13401|1109994594201763840|  false|Map(18:tp2 -> 0.2...|        18|[0.06157120217893...|\n",
      "|     14805|1106473203766657024|  false|Map(18:text -> 0....|        18|[0.11028526511434...|\n",
      "|     14805|1113049054452908032|  false|Map(18:text -> 0....|        18|[0.11028526511434...|\n",
      "|     14805|1114863751909081088|   true|Map(18:text -> 2....|        18|[0.11028526511434...|\n",
      "|     14805|1115534909935452160|  false|Map(18:text -> 2....|        18|[0.11028526511434...|\n",
      "|     14805|1103195673450250240|  false|Map(18:Animal -> ...|        18|[0.11028526511434...|\n",
      "|     14805|1105045287866466304|  false|Map(18:text -> 0....|        18|[0.11028526511434...|\n",
      "|     14805|1114864237131333632|  false|Map(18:vars -> 0....|        18|[0.11028526511434...|\n",
      "|     14805|1109995264376045568|  false|Map(18:text -> 0....|        18|[0.11028526511434...|\n",
      "|     14805|1111524501104885760|  false|Map(18:text -> 0....|        18|[0.11028526511434...|\n",
      "|     14805|1105105185656537088|  false|Map(18:SHOldboySt...|        18|[0.11028526511434...|\n",
      "|     14805|1114864128259784704|  false|Map(18:text -> 0....|        18|[0.11028526511434...|\n",
      "|     14805|1114864233264185344|  false|Map(18:text -> 0....|        18|[0.11028526511434...|\n",
      "|     14805|1115436666438287360|  false|Map(18:text -> 0....|        18|[0.11028526511434...|\n",
      "|     14805|1114863846486441984|  false|Map(18:vars -> 0....|        18|[0.11028526511434...|\n",
      "|     14805|1115089292662669312|  false|Map(18:vars -> 0....|        18|[0.11028526511434...|\n",
      "|     14805|1114863902073552896|  false|Map(18:Animal -> ...|        18|[0.11028526511434...|\n",
      "+----------+-------------------+-------+--------------------+----------+--------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Preview the training frame after the article_vector join.\n",
    "train.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the article profiles (article_id and keywords).\n",
    "ctr.spark.sql(\"use article\")\n",
    "article_profile = ctr.spark.sql(\"select article_id, keywords from article_profile\")\n",
    "# 处理文章权重\n",
    "def get_article_weights(row):\n",
    "    \n",
    "    try:\n",
    "        weights = sorted(row.keywords.values())[:10]\n",
    "    except Exception as e:\n",
    "        # 给定异常默认值\n",
    "        weights = [0.0] * 10\n",
    "    \n",
    "    return row.article_id, weights\n",
    "\n",
    "# Replace the raw profiles with (article_id, article_weights) rows.\n",
    "article_profile = article_profile.rdd.map(get_article_weights).toDF(['article_id', 'article_weights'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the article_weights column from article_profile to the training rows (left join on article_id).\n",
    "train = train.join(article_profile, on=['article_id'], how='left')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------+-------------------+-------+--------------------+----------+--------------------+--------------------+\n",
      "|article_id|            user_id|clicked|     article_partial|channel_id|       articlevector|     article_weights|\n",
      "+----------+-------------------+-------+--------------------+----------+--------------------+--------------------+\n",
      "|     13401|1114864237131333632|  false|Map(18:vars -> 0....|        18|[0.06157120217893...|[0.08196639249252...|\n",
      "|     13401|                 10|  false|Map(18:tp2 -> 0.2...|        18|[0.06157120217893...|[0.08196639249252...|\n",
      "|     13401|1106396183141548032|  false|Map(18:tp2 -> 0.2...|        18|[0.06157120217893...|[0.08196639249252...|\n",
      "|     13401|1109994594201763840|  false|Map(18:tp2 -> 0.2...|        18|[0.06157120217893...|[0.08196639249252...|\n",
      "|     14805|1106473203766657024|  false|Map(18:text -> 0....|        18|[0.11028526511434...|[0.15069781969741...|\n",
      "|     14805|1113049054452908032|  false|Map(18:text -> 0....|        18|[0.11028526511434...|[0.15069781969741...|\n",
      "|     14805|1114863751909081088|   true|Map(18:text -> 2....|        18|[0.11028526511434...|[0.15069781969741...|\n",
      "|     14805|1115534909935452160|  false|Map(18:text -> 2....|        18|[0.11028526511434...|[0.15069781969741...|\n",
      "|     14805|1103195673450250240|  false|Map(18:Animal -> ...|        18|[0.11028526511434...|[0.15069781969741...|\n",
      "|     14805|1105045287866466304|  false|Map(18:text -> 0....|        18|[0.11028526511434...|[0.15069781969741...|\n",
      "|     14805|1114864237131333632|  false|Map(18:vars -> 0....|        18|[0.11028526511434...|[0.15069781969741...|\n",
      "|     14805|1109995264376045568|  false|Map(18:text -> 0....|        18|[0.11028526511434...|[0.15069781969741...|\n",
      "|     14805|1111524501104885760|  false|Map(18:text -> 0....|        18|[0.11028526511434...|[0.15069781969741...|\n",
      "|     14805|1105105185656537088|  false|Map(18:SHOldboySt...|        18|[0.11028526511434...|[0.15069781969741...|\n",
      "|     14805|1114864128259784704|  false|Map(18:text -> 0....|        18|[0.11028526511434...|[0.15069781969741...|\n",
      "|     14805|1114864233264185344|  false|Map(18:text -> 0....|        18|[0.11028526511434...|[0.15069781969741...|\n",
      "|     14805|1115436666438287360|  false|Map(18:text -> 0....|        18|[0.11028526511434...|[0.15069781969741...|\n",
      "|     14805|1114863846486441984|  false|Map(18:vars -> 0....|        18|[0.11028526511434...|[0.15069781969741...|\n",
      "|     14805|1115089292662669312|  false|Map(18:vars -> 0....|        18|[0.11028526511434...|[0.15069781969741...|\n",
      "|     14805|1114863902073552896|  false|Map(18:Animal -> ...|        18|[0.11028526511434...|[0.15069781969741...|\n",
      "+----------+-------------------+-------+--------------------+----------+--------------------+--------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Preview the training frame with article_weights attached.\n",
    "train.show()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DataFrame[article_id: bigint, user_id: bigint, clicked: boolean, article_partial: map<string,double>, channel_id: int, articlevector: array<double>, article_weights: array<double>]"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the joined training frame (its repr lists column names and types).\n",
    "train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (4) Select the user weight features and convert the column types.\n",
    "# Rows containing a null value are removed first.\n",
    "train = train.dropna()\n",
    "\n",
    "# Output column names, in the order in which get_user_weights returns its values.\n",
    "columns = ['article_id', 'user_id', 'channel_id', 'articlevector', 'user_weights', 'article_weights', 'clicked']\n",
    "# array --->vecoter\n",
    "def get_user_weights(row):\n",
    "    \n",
    "    # 取出所有对应particle平道的关键词权重（用户）\n",
    "    from pyspark.ml.linalg import Vectors\n",
    "    try:\n",
    "        weights = sorted([row.article_partial[key] for key in \n",
    "                          row.article_partial.keys() if key.split(':')[0] == str(row.channel_id)])[:10]\n",
    "    except Exception as e:\n",
    "        weights = [0.0] * 10\n",
    "    \n",
    "    return row.article_id, row.user_id, row.channel_id, Vectors.dense(row.articlevector), Vectors.dense(weights), Vectors.dense(row.article_weights),int(row.clicked) \n",
    "\n",
    "# Apply the conversion to every row and name the resulting columns.\n",
    "train_1 = train.rdd.map(get_user_weights).toDF(columns)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DataFrame[article_id: bigint, user_id: bigint, channel_id: bigint, articlevector: vector, user_weights: vector, article_weights: vector, clicked: bigint]"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the converted frame (its repr lists column names and types).\n",
    "train_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the feature columns into a single `features` column.\n",
    "# columns[2:6] is channel_id, articlevector, user_weights, article_weights.\n",
    "train_vecrsion_two = VectorAssembler().setInputCols(columns[2:6]).setOutputCol('features').transform(train_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------+-------------------+----------+--------------------+--------------------+--------------------+-------+--------------------+\n",
      "|article_id|            user_id|channel_id|       articlevector|        user_weights|     article_weights|clicked|            features|\n",
      "+----------+-------------------+----------+--------------------+--------------------+--------------------+-------+--------------------+\n",
      "|     13401|1114864237131333632|        18|[0.06157120217893...|[0.32473420471378...|[0.08196639249252...|      0|[18.0,0.061571202...|\n",
      "|     13401|                 10|        18|[0.06157120217893...|[0.21215332784742...|[0.08196639249252...|      0|[18.0,0.061571202...|\n",
      "|     13401|1106396183141548032|        18|[0.06157120217893...|[0.22553064631951...|[0.08196639249252...|      0|[18.0,0.061571202...|\n",
      "|     13401|1109994594201763840|        18|[0.06157120217893...|[0.24443647588626...|[0.08196639249252...|      0|[18.0,0.061571202...|\n",
      "|     14805|1106473203766657024|        18|[0.11028526511434...|[0.22553064631951...|[0.15069781969741...|      0|[18.0,0.110285265...|\n",
      "|     14805|1113049054452908032|        18|[0.11028526511434...|[0.28050889359956...|[0.15069781969741...|      0|[18.0,0.110285265...|\n",
      "|     14805|1114863751909081088|        18|[0.11028526511434...|[0.32473420471378...|[0.15069781969741...|      1|[18.0,0.110285265...|\n",
      "|     14805|1115534909935452160|        18|[0.11028526511434...|[0.35819704778381...|[0.15069781969741...|      0|[18.0,0.110285265...|\n",
      "|     14805|1103195673450250240|        18|[0.11028526511434...|[0.21442838668808...|[0.15069781969741...|      0|[18.0,0.110285265...|\n",
      "|     14805|1105045287866466304|        18|[0.11028526511434...|[0.21952219380422...|[0.15069781969741...|      0|[18.0,0.110285265...|\n",
      "|     14805|1114864237131333632|        18|[0.11028526511434...|[0.32473420471378...|[0.15069781969741...|      0|[18.0,0.110285265...|\n",
      "|     14805|1109995264376045568|        18|[0.11028526511434...|[0.24443647588626...|[0.15069781969741...|      0|[18.0,0.110285265...|\n",
      "|     14805|1111524501104885760|        18|[0.11028526511434...|[0.26087773109487...|[0.15069781969741...|      0|[18.0,0.110285265...|\n",
      "|     14805|1105105185656537088|        18|[0.11028526511434...|[0.21952219380422...|[0.15069781969741...|      0|[18.0,0.110285265...|\n",
      "|     14805|1114864128259784704|        18|[0.11028526511434...|[0.32473420471378...|[0.15069781969741...|      0|[18.0,0.110285265...|\n",
      "|     14805|1114864233264185344|        18|[0.11028526511434...|[0.32473420471378...|[0.15069781969741...|      0|[18.0,0.110285265...|\n",
      "|     14805|1115436666438287360|        18|[0.11028526511434...|[0.35819704778381...|[0.15069781969741...|      0|[18.0,0.110285265...|\n",
      "|     14805|1114863846486441984|        18|[0.11028526511434...|[0.32473420471378...|[0.15069781969741...|      0|[18.0,0.110285265...|\n",
      "|     14805|1115089292662669312|        18|[0.11028526511434...|[0.33945366606672...|[0.15069781969741...|      0|[18.0,0.110285265...|\n",
      "|     14805|1114863902073552896|        18|[0.11028526511434...|[0.32473420471378...|[0.15069781969741...|      0|[18.0,0.110285265...|\n",
      "+----------+-------------------+----------+--------------------+--------------------+--------------------+-------+--------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# features: 121 values, 13, 18,       1,2,3,4,5,6....25\n",
    "# 25 + 100 + 10 + 10 = 145 features\n",
    "# NOTE(review): the two counts above differ; presumably 121 counts channel_id as one raw value\n",
    "# and 145 counts it as a 25-wide encoding -- confirm which one applies.\n",
    "train_vecrsion_two.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# lr = LogisticRegression()\n",
    "# model = lr.setLabelCol(\"clicked\").setFeaturesCol(\"features\").fit(train_vecrsion_two)\n",
    "# model.save(\"hdfs://hadoop-master:9000/headlines/models/test_ctr.obj\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# online_model = LogisticRegressionModel.load(\"hdfs://hadoop-master:9000/headlines/models/logistic_ctr_model.obj\")\n",
    "\n",
    "# res_transfrom = online_model.transform(train_vecrsion_two)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def vector_to_double(row):\n",
    "    return float(row.clicked), float(row.probability[1]) \n",
    "\n",
    "# NOTE(review): `res_transfrom` is only assigned in commented-out code in the cell above,\n",
    "# so this cell raises NameError on a fresh kernel until that code is restored.\n",
    "score_label = res_transfrom.select([\"clicked\", \"probability\"]).rdd.map(vector_to_double)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clicked: the target value (ground truth)\n",
    "# probability: [probability of no click, probability of click]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the samples\n",
    "ctr.spark.sql(\"use profile\")\n",
    "\n",
    "user_profile_hbase = ctr.spark.sql(\n",
    "    \"select user_id, information.birthday, information.gender, article_partial, env from user_profile_hbase\")\n",
    "\n",
    "# Feature engineering\n",
    "# Drop the features for which few values are available\n",
    "user_profile_hbase = user_profile_hbase.drop('env', 'birthday', 'gender')\n",
    "\n",
    "def get_user_id(row):\n",
    "    return int(row.user_id.split(\":\")[1]), row.article_partial\n",
    "\n",
    "# Map every profile row to a (user_id, article_partial) tuple.\n",
    "user_profile_hbase_temp = user_profile_hbase.rdd.map(get_user_id)\n",
    "\n",
    "# NOTE(review): repeats the wildcard import already made in the first cell.\n",
    "from pyspark.sql.types import *\n",
    "\n",
    "# Schema for the tuples above; the article_partial map is stored under the column name `weights`.\n",
    "_schema = StructType([\n",
    "    StructField(\"user_id\", LongType()),\n",
    "    StructField(\"weights\", MapType(StringType(), DoubleType()))\n",
    "])\n",
    "\n",
    "user_profile_hbase_schema = ctr.spark.createDataFrame(user_profile_hbase_temp, schema=_schema)\n",
    "\n",
    "def frature_preprocess(row):\n",
    "\n",
    "    from pyspark.ml.linalg import Vectors\n",
    "\n",
    "    channel_weights = []\n",
    "    for i in range(1, 26):\n",
    "        try:\n",
    "            _res = sorted([row.weights[key] for key\n",
    "                           in row.weights.keys() if key.split(':')[0] == str(i)])[:10]\n",
    "            channel_weights.append(_res)\n",
    "        except:\n",
    "            channel_weights.append([0.0] * 10)\n",
    "\n",
    "    return row.user_id, channel_weights\n",
    "\n",
    "# Collect all (user_id, channel_weights) results into a local list.\n",
    "res = user_profile_hbase_schema.rdd.map(frature_preprocess).collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "72\n"
     ]
    }
   ],
   "source": [
    "# Number of user rows collected.\n",
    "print(len(res))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (10,\n",
    "#   [[],\n",
    "#    [],\n",
    "#    [],\n",
    "#    [],\n",
    "#    [],\n",
    "#    [],\n",
    "#    [],\n",
    "#    [],\n",
    "#    [],\n",
    "#    [],\n",
    "#    [],\n",
    "#    [],\n",
    "#    [],\n",
    "#    [],\n",
    "#    [],\n",
    "#    [],\n",
    "#    [],\n",
    "#    [0.21215332784742846,\n",
    "#     0.21215332784742846,\n",
    "#     0.21215332784742846,\n",
    "#     0.21215332784742846,\n",
    "#     0.21215332784742846,\n",
    "#     0.21215332784742846,\n",
    "#     0.21215332784742846,\n",
    "#     0.21215332784742846,\n",
    "#     0.21215332784742846,\n",
    "#     0.21215332784742846],\n",
    "#    [],\n",
    "#    [],\n",
    "#    [],\n",
    "#    [],\n",
    "#    [],\n",
    "#    [],\n",
    "#    []])\n",
    "import happybase\n",
    "# Bulk-insert the per-channel user features into HBase.\n",
    "pool = happybase.ConnectionPool(size=10, host='hadoop-master', port=9090)\n",
    "with pool.connection() as conn:\n",
    "    ctr_feature_user = conn.table('ctr_feature_user')\n",
    "    with ctr_feature_user.batch(transaction=True) as b:\n",
    "        for user_id, channel_weights in res:\n",
    "            row_key = '{}'.format(user_id).encode()\n",
    "            # Column names are 1-based (channel:1 .. channel:25); list positions are 0-based.\n",
    "            for channel_no in range(1, 26):\n",
    "                column = 'channel:{}'.format(channel_no).encode()\n",
    "                b.put(row_key, {column: str(channel_weights[channel_no - 1]).encode()})\n",
    "    conn.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
